import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import sys
import os

from advtr.data import train_loader, test_loader
from advtr.attacks import fgsm, rfgsm, pgd, rpgd, pgd2, rpgd2, pgd_linf, rpgd_linf
from advtr.train import epoch, epoch_adversarial
from advtr.model import model_gen


# Command-line interface: python <script> NET EPS ALPHA SEED
#   NET   - name passed to model_gen to build each network
#   EPS   - forwarded to the attacks as `epsilon`
#   ALPHA - forwarded to the attacks as `alpha`
#   SEED  - seed for torch.manual_seed
try:
    net, eps, alpha, seed = sys.argv[1], float(sys.argv[2]), float(sys.argv[3]), int(sys.argv[4])
except (IndexError, ValueError):
    # Fail with a readable message instead of a bare IndexError/ValueError traceback.
    sys.exit("usage: python %s NET EPS ALPHA SEED" % os.path.basename(sys.argv[0]))

torch.manual_seed(seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Every artifact of a run shares the same hyper-parameter suffix; the resulting
# paths are byte-identical to the original per-file format strings.
_run_suffix = "net=%s_eps=%s_alpha=%s_seed=%s.pt" % (net, eps, alpha, seed)
file_pgd = "./results/models/model_atk=pgd_" + _run_suffix
file_rpgd = "./results/models/model_atk=rpgd_" + _run_suffix
file_pgd2 = "./results/models/model_atk=pgd2_" + _run_suffix
file_rpgd2 = "./results/models/model_atk=rpgd2_" + _run_suffix
stats_pgd = "./results/stats/stats_atk=pgd_" + _run_suffix
stats_rpgd = "./results/stats/stats_atk=rpgd_" + _run_suffix
stats_pgd2 = "./results/stats/stats_atk=pgd2_" + _run_suffix
stats_rpgd2 = "./results/stats/stats_atk=rpgd2_" + _run_suffix

# exist_ok=True removes the check-then-create race: when several runs (e.g.
# different seeds) start concurrently, a plain makedirs after an exists()
# check can raise FileExistsError in all but one of them.
os.makedirs('./results/models', exist_ok=True)
os.makedirs('./results/stats', exist_ok=True)

# One freshly built network per training variant, created in the fixed order
# pgd, rpgd, pgd2, rpgd2 (presumably model_gen draws initial weights from the
# torch RNG seeded above, so this order matters for reproducibility — verify).
model_pgd, model_rpgd, model_pgd2, model_rpgd2 = [model_gen(net).to(device) for _ in range(4)]

# A plain SGD optimizer (initial learning rate 1) for each model, same order.
opt_pgd, opt_rpgd, opt_pgd2, opt_rpgd2 = [
    optim.SGD(m.parameters(), lr=1)
    for m in (model_pgd, model_rpgd, model_pgd2, model_rpgd2)
]

# Per-epoch statistics rows, one independent list per variant.
pgd_stat, rpgd_stat, pgd2_stat, rpgd2_stat = ([] for _ in range(4))
def _train_and_evaluate(model, train_attack, opt, eval_attack):
    """Run one adversarial-training epoch on ``model``, then evaluate it.

    Trains on ``train_loader`` with ``train_attack``, then measures clean
    error/loss and adversarial error/loss (under ``eval_attack``) on
    ``test_loader``.

    Returns:
        [train_err, train_loss, test_err, test_loss, adv_err, adv_loss]
    """
    train_err, train_loss = epoch_adversarial(train_loader, model, train_attack, opt, epsilon=eps, alpha=alpha)
    test_err, test_loss = epoch(test_loader, model)
    adv_err, adv_loss = epoch_adversarial(test_loader, model, eval_attack, epsilon=eps, alpha=alpha)
    return [train_err, train_loss, test_err, test_loss, adv_err, adv_loss]


# (checkpoint path, model, training attack, optimizer, evaluation attack, stats list).
# The randomised variants are evaluated against the plain attacks (pgd / pgd2)
# so all four models are compared under the same test-time adversary.
_runs = [
    (file_pgd, model_pgd, pgd, opt_pgd, pgd, pgd_stat),
    (file_rpgd, model_rpgd, rpgd, opt_rpgd, pgd, rpgd_stat),
    (file_pgd2, model_pgd2, pgd2, opt_pgd2, pgd2, pgd2_stat),
    (file_rpgd2, model_rpgd2, rpgd2, opt_rpgd2, pgd2, rpgd2_stat),
]

for t in range(20):
    # (train_err, test_err, adv_err) per variant, in _runs order. A variant whose
    # checkpoint already exists is not trained and reports NaN, instead of the
    # print below hitting an unbound name (NameError) or stale values.
    epoch_errs = []
    for ckpt_path, model, train_attack, opt, eval_attack, stats in _runs:
        if os.path.exists(ckpt_path):
            epoch_errs.append((float("nan"),) * 3)
            continue
        row = _train_and_evaluate(model, train_attack, opt, eval_attack)
        stats.append(row)
        epoch_errs.append((row[0], row[2], row[4]))

    # Step decay: divide every optimizer's learning rate by 5 after epochs 5, 10, 15, 20.
    if t % 5 == 4:
        for opt in (opt_pgd, opt_rpgd, opt_pgd2, opt_rpgd2):
            for param_group in opt.param_groups:
                param_group["lr"] /= 5

    # Row 1: pgd then rpgd; row 2: pgd2 then rpgd2 (train, test, adv error each).
    print(*("{:.6f}".format(e) for e in epoch_errs[0] + epoch_errs[1]), sep="\t")
    print(*("{:.6f}".format(e) for e in epoch_errs[2] + epoch_errs[3]), sep="\t")
    print('\n')

# Persist statistics and weights for every variant trained in this run.
# Variants whose checkpoint already existed were skipped above and are left untouched.
for ckpt_path, stats_path, stats, model in (
    (file_pgd, stats_pgd, pgd_stat, model_pgd),
    (file_rpgd, stats_rpgd, rpgd_stat, model_rpgd),
    (file_pgd2, stats_pgd2, pgd2_stat, model_pgd2),
    (file_rpgd2, stats_rpgd2, rpgd2_stat, model_rpgd2),
):
    if os.path.exists(ckpt_path):
        continue
    torch.save(stats, stats_path)
    torch.save(model.state_dict(), ckpt_path)